# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore MLP UT of inference."""
import numpy as np

# Shape constants for the generated test inputs. Note: the reference arrays
# returned by get_golden()/get_gpu_data() are hard-coded literals of shape
# (4, 32), i.e. (BATCH_SIZE * SEQ_LENGTH, INPUT_SIZE) for the values below;
# changing these constants does not regenerate those arrays.
BATCH_SIZE: int = 2
SEQ_LENGTH: int = 2
INPUT_SIZE: int = 32
FFN_HIDDEN_SIZE: int = 32


def get_init_params(input_size, ffn_hidden_size, *, batch_size=None,
                    seq_length=None):
    """Generate initialization parameters for the MLP UT.

    NumPy's *global* random generator is re-seeded with 2025 on every call.
    Repeated calls with the same arguments therefore return identical arrays.
    The reseed also resets the global RNG state seen by the caller. All
    arrays are float64 and drawn from ``np.random.rand`` (uniform on
    [0, 1)). Weights and biases are additionally scaled by 0.01.

    Args:
        input_size: Feature size of the input rows; also the length of
            ``fc2_bias`` (i.e. the output size of fc2).
        ffn_hidden_size: Intermediate size between fc1 and fc2.
        batch_size: Batch dimension of the generated input. ``None`` (the
            default) uses the module-level ``BATCH_SIZE``.
        seq_length: Sequence dimension of the generated input. ``None`` (the
            default) uses the module-level ``SEQ_LENGTH``.

    Returns:
        dict[str, np.ndarray] with these entries and shapes:
            "input":              (batch_size * seq_length, input_size)
            "fc1_gate_weight":    (2 * ffn_hidden_size, input_size)
            "fc1_no_gate_weight": (ffn_hidden_size, input_size)
            "fc2_weight":         (input_size, ffn_hidden_size)
            "fc1_gate_bias":      (2 * ffn_hidden_size,)
            "fc1_no_gate_bias":   (ffn_hidden_size,)
            "fc2_bias":           (input_size,)
        The gated fc1 tensors are twice as wide. Presumably the gate and up
        projections are stacked -- TODO confirm against the MLP layer.
    """
    # Resolve at call time so the defaults always track the module constants.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if seq_length is None:
        seq_length = SEQ_LENGTH

    np.random.seed(2025)
    num_tokens = batch_size * seq_length
    fc1_add_gate_weight_shape = (ffn_hidden_size * 2, input_size)
    fc1_no_gate_weight_shape = (ffn_hidden_size, input_size)
    fc2_weight_shape = (input_size, ffn_hidden_size)
    # Draw order matters: every entry consumes the same seeded stream. The
    # hard-coded reference outputs in this file presumably depend on this
    # exact sequence, so do not reorder the entries below.
    return {
        "input": np.random.rand(num_tokens, input_size),
        "fc1_gate_weight": 0.01 * np.random.rand(*fc1_add_gate_weight_shape),
        "fc1_no_gate_weight": 0.01 * np.random.rand(*fc1_no_gate_weight_shape),
        "fc2_weight": 0.01 * np.random.rand(*fc2_weight_shape),
        "fc1_gate_bias": 0.01 * np.random.rand(ffn_hidden_size * 2),
        "fc1_no_gate_bias": 0.01 * np.random.rand(ffn_hidden_size),
        "fc2_bias": 0.01 * np.random.rand(input_size)
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden outputs for the mcore MLP inference UT.

    Every array is a float32 literal of shape (4, 32). That equals
    (BATCH_SIZE * SEQ_LENGTH, INPUT_SIZE) for the module constants. The values
    are embedded rather than recomputed, so they are only valid for the inputs
    that produced them. Presumably those inputs are
    ``get_init_params(INPUT_SIZE, FFN_HIDDEN_SIZE)`` (seed 2025) -- TODO
    confirm against the script that generated these numbers.

    Returns:
        dict with four float32 (4, 32) arrays. Judging only by the key names,
        they presumably correspond to the MLP variants exercised by the UT
        (verify against the test that consumes them):
            "has_bias_gate_output": bias enabled, gated fc1.
            "has_gate_output":      gated fc1, no bias.
            "has_bias_output":      bias enabled, no gate.
            "output_only":          neither bias nor gate.
    """
    has_bias_gate_output = np.array(
        [[0.00838303, 0.00728844, 0.00262961, 0.00642217, 0.00839234,
          0.00784265, 0.00350671, 0.00696487, 0.0103511, 0.00516138,
          0.00440877, 0.00787543, 0.00644731, 0.00560005, 0.00720802,
          0.00887262, 0.00078947, 0.00684218, 0.00749716, 0.00872077,
          0.00178315, 0.00804651, 0.00704303, 0.00415782, 0.00103956,
          0.00662815, 0.0092494, 0.00880827, 0.00841713, 0.00823413,
          0.00263828, 0.00315325],
         [0.00844887, 0.00736587, 0.00268006, 0.00649034, 0.00846487,
          0.00790741, 0.00358242, 0.00702974, 0.01041883, 0.00523255,
          0.00447889, 0.00791309, 0.00652121, 0.0056381, 0.00727176,
          0.00894632, 0.0008572, 0.00689829, 0.00757435, 0.00877497,
          0.00186433, 0.0081053, 0.00710625, 0.00421552, 0.00109083,
          0.00668698, 0.00929842, 0.00886845, 0.0084743, 0.00831737,
          0.00270922, 0.00320634],
         [0.0081311, 0.00703599, 0.00243001, 0.00623019, 0.00810873,
          0.00762472, 0.00327798, 0.00674674, 0.01011598, 0.00490351,
          0.00416127, 0.00764442, 0.00617755, 0.00541439, 0.00696468,
          0.00863892, 0.00053194, 0.0066292, 0.00724788, 0.00849597,
          0.00154071, 0.00779932, 0.00685628, 0.00391368, 0.00080307,
          0.00639863, 0.0090129, 0.00859149, 0.00824183, 0.00795795,
          0.00241081, 0.00288942],
         [0.00817706, 0.00708824, 0.00246002, 0.00627149, 0.0081659,
          0.00765846, 0.00331535, 0.00677396, 0.01014467, 0.00495977,
          0.00420982, 0.00769886, 0.00622435, 0.00543528, 0.0070222,
          0.00868806, 0.00057166, 0.00666396, 0.00728669, 0.00855382,
          0.00159172, 0.00784851, 0.00689046, 0.00396291, 0.00085417,
          0.00643766, 0.0090601, 0.00862996, 0.00827296, 0.00800964,
          0.00244809, 0.00294001]], dtype=np.float32
    )

    has_gate_output = np.array(
        [[0.00061804, 0.00065196, 0.00050569, 0.00049526, 0.00069376,
          0.00056426, 0.00058755, 0.00055758, 0.0005956, 0.00063853,
          0.00062807, 0.00055279, 0.00069377, 0.00044777, 0.00059069,
          0.0005746, 0.00064622, 0.00052796, 0.00063937, 0.00055674,
          0.00060809, 0.00061874, 0.00047239, 0.00061491, 0.00058301,
          0.00059383, 0.00056958, 0.00054375, 0.00044742, 0.00068555,
          0.00056984, 0.00063744],
         [0.0006803, 0.00072538, 0.00055413, 0.0005597, 0.00076234,
          0.00062602, 0.00065949, 0.00061902, 0.00066035, 0.00070633,
          0.00069508, 0.00058883, 0.00076421, 0.00048416, 0.00065059,
          0.00064419, 0.0007109, 0.00058118, 0.00071296, 0.00060806,
          0.000685, 0.00067525, 0.00053212, 0.00067018, 0.0006318,
          0.00064979, 0.00061604, 0.00060083, 0.00050197, 0.00076439,
          0.0006372, 0.00068829],
         [0.00038521, 0.00041914, 0.00032165, 0.00031818, 0.00043188,
          0.00036254, 0.00037602, 0.00035669, 0.0003776, 0.00040089,
          0.00039987, 0.00034072, 0.00044402, 0.0002761, 0.00036627,
          0.00035907, 0.00040768, 0.00033118, 0.00040878, 0.00034938,
          0.00038479, 0.00039083, 0.00029916, 0.00038956, 0.00036434,
          0.00038175, 0.00035125, 0.00034426, 0.00028611, 0.00043044,
          0.00035928, 0.00039418],
         [0.00042687, 0.00046613, 0.00034829, 0.0003554, 0.00048334,
          0.00039317, 0.00040956, 0.00038099, 0.00040325, 0.00045129,
          0.00044313, 0.00038982, 0.00048638, 0.00029435, 0.0004183,
          0.00040335, 0.00044355, 0.00036277, 0.00044361, 0.00040204,
          0.00043064, 0.00043518, 0.00033022, 0.00043412, 0.00041079,
          0.00041678, 0.00039387, 0.00037871, 0.0003136, 0.00047718,
          0.00039278, 0.00043938]], dtype=np.float32
    )

    has_bias_output = np.array(
        [[0.01555738, 0.01518457, 0.00867038, 0.01247342, 0.01659538,
          0.01438225, 0.0107496, 0.01371755, 0.01747923, 0.0127057,
          0.01184535, 0.01427048, 0.01483543, 0.01077205, 0.0142104,
          0.01554536, 0.00853903, 0.01315831, 0.0153219, 0.01523462,
          0.009009, 0.0152718, 0.01262433, 0.01141226, 0.0077691,
          0.01361223, 0.01586456, 0.01541413, 0.01402102, 0.01630843,
          0.00962818, 0.01068564],
         [0.01599199, 0.01559386, 0.00905177, 0.01281861, 0.01716332,
          0.01476059, 0.01119076, 0.01421434, 0.01788127, 0.01314647,
          0.01234917, 0.01457499, 0.01530058, 0.01118175, 0.01470033,
          0.01601679, 0.00892377, 0.01360942, 0.01579724, 0.01563667,
          0.00949631, 0.01573741, 0.01297542, 0.01176408, 0.00815579,
          0.01403268, 0.01636559, 0.01578651, 0.01442906, 0.01672856,
          0.01004083, 0.0112652],
         [0.01405674, 0.01355876, 0.00739947, 0.01123314, 0.01490179,
          0.01298387, 0.00919183, 0.0123187, 0.01590243, 0.01118586,
          0.01034639, 0.01286399, 0.01306041, 0.00975238, 0.0127456,
          0.01415392, 0.00689243, 0.0118171, 0.01363012, 0.01383892,
          0.00748824, 0.01373825, 0.01142307, 0.00985883, 0.00636299,
          0.01218575, 0.0145352, 0.01401662, 0.01292906, 0.01454648,
          0.00812417, 0.00913476],
         [0.01410933, 0.0137259, 0.00751418, 0.01140136, 0.01507689,
          0.01311683, 0.00935614, 0.01254476, 0.01600985, 0.01125391,
          0.01050307, 0.01290799, 0.01314826, 0.0097773, 0.01280745,
          0.01431048, 0.00696046, 0.01198903, 0.01369895, 0.01393838,
          0.00755944, 0.01399155, 0.0114097, 0.00991082, 0.00644973,
          0.01235832, 0.01464044, 0.01404685, 0.01305206, 0.01471789,
          0.00823076, 0.00919707]], dtype=np.float32
    )

    output_only = np.array(
        [[0.00737803, 0.00809012, 0.00621319, 0.00620454, 0.00844559,
          0.00676225, 0.00740756, 0.00693321, 0.00736123, 0.00777809,
          0.00763244, 0.00657675, 0.00859106, 0.00534837, 0.00720232,
          0.00688425, 0.00797988, 0.00651171, 0.0080434, 0.00670953,
          0.00744726, 0.00745602, 0.00573933, 0.00746953, 0.00696929,
          0.00715746, 0.00681637, 0.00676596, 0.00574044, 0.00830984,
          0.00718772, 0.00776022],
         [0.00781033, 0.00849738, 0.00659293, 0.00654814, 0.00901102,
          0.00713884, 0.0078464, 0.00742774, 0.00776147, 0.0082169,
          0.00813386, 0.00687948, 0.00905363, 0.00575626, 0.00769001,
          0.00735352, 0.00836283, 0.00696077, 0.00851658, 0.00710974,
          0.00793248, 0.0079195, 0.00608883, 0.00781945, 0.00735434,
          0.00757576, 0.00731482, 0.00713629, 0.00614664, 0.00872788,
          0.0075985, 0.00833716],
         [0.00588491, 0.00647234, 0.00494847, 0.00497024, 0.00676003,
          0.00537017, 0.00585757, 0.00554111, 0.00579146, 0.0062653,
          0.00614091, 0.00517719, 0.00682484, 0.00433333, 0.00574468,
          0.00549915, 0.00634128, 0.00517667, 0.00635959, 0.00532055,
          0.0059334, 0.00592948, 0.00454398, 0.00592351, 0.0055695,
          0.00573821, 0.00549325, 0.00537539, 0.00465375, 0.00655622,
          0.00569081, 0.00621665],
         [0.00593735, 0.0066389, 0.00506305, 0.00513776, 0.00693472,
          0.00550281, 0.00602128, 0.00576641, 0.00589877, 0.00633345,
          0.0062971, 0.0052212, 0.00691249, 0.00435843, 0.00580637,
          0.00565526, 0.00640933, 0.00534811, 0.00642853, 0.00541962,
          0.00600462, 0.00618174, 0.00453085, 0.00597541, 0.005656,
          0.00591011, 0.00559812, 0.00540579, 0.00477643, 0.00672697,
          0.0057971, 0.00627912]], dtype=np.float32
    )

    return {
        "has_bias_gate_output": has_bias_gate_output,
        "has_gate_output": has_gate_output,
        "has_bias_output": has_bias_output,
        "output_only": output_only,
    }


def get_gpu_data() -> dict[str, np.ndarray]:
    """Return hard-coded reference outputs stored as float16.

    The keys match those of ``get_golden()``, and every array is a float16
    literal of shape (4, 32). Judging by the function name, these are
    presumably outputs captured from a GPU run of the reference
    implementation -- TODO confirm provenance. float16 keeps only about three
    significant decimal digits. Any comparison against these arrays therefore
    needs a tolerance that reflects that precision.

    Returns:
        dict with keys "has_bias_gate_output", "has_gate_output",
        "has_bias_output" and "output_only", each a float16 (4, 32) array.
    """
    has_bias_gate_output = np.array(
        [[0.00842, 0.007263, 0.002625, 0.00644, 0.00842,
          0.00787, 0.00351, 0.006958, 0.010376, 0.005157,
          0.004395, 0.00787, 0.00644, 0.005615, 0.007202,
          0.00891, 0.0007896, 0.006836, 0.007507, 0.00867,
          0.001778, 0.00806, 0.00705, 0.00415, 0.001038,
          0.006622, 0.00928, 0.00879, 0.00842, 0.00824,
          0.00264, 0.003159],
         [0.008484, 0.007355, 0.00267, 0.0065, 0.008484,
          0.007935, 0.003586, 0.00702, 0.01044, 0.00522,
          0.004486, 0.007935, 0.0065, 0.005646, 0.007263,
          0.00897, 0.0008583, 0.006897, 0.00757, 0.00873,
          0.001862, 0.00812, 0.00711, 0.00421, 0.001091,
          0.006683, 0.00928, 0.00885, 0.008484, 0.0083,
          0.0027, 0.003204],
         [0.00812, 0.00702, 0.002426, 0.006226, 0.00812,
          0.00763, 0.00328, 0.006744, 0.01013, 0.004883,
          0.00415, 0.00763, 0.006165, 0.005432, 0.006958,
          0.00867, 0.000534, 0.006622, 0.007263, 0.008484,
          0.001541, 0.007812, 0.006836, 0.003906, 0.000805,
          0.00641, 0.00903, 0.008606, 0.0083, 0.007935,
          0.00241, 0.002884],
         [0.00818, 0.00708, 0.002457, 0.006287, 0.00818,
          0.00766, 0.003311, 0.006775, 0.01019, 0.004944,
          0.00421, 0.00769, 0.006226, 0.005432, 0.00702,
          0.00873, 0.000572, 0.006653, 0.007294, 0.008545,
          0.001587, 0.00787, 0.006897, 0.003967, 0.0008545,
          0.00644, 0.009094, 0.008606, 0.0083, 0.007996,
          0.002441, 0.002945]], dtype=np.float16
    )

    has_gate_output = np.array(
        [[0.000618, 0.0006523, 0.0005035, 0.000496, 0.0006943,
          0.0005646, 0.0005875, 0.000557, 0.000595, 0.000637,
          0.0006256, 0.000553, 0.0006943, 0.0004482, 0.0005913,
          0.000576, 0.0006447, 0.0005264, 0.000637, 0.000557,
          0.0006065, 0.000618, 0.000471, 0.000614, 0.0005836,
          0.000595, 0.0005684, 0.0005417, 0.0004463, 0.0006866,
          0.0005684, 0.000637],
         [0.000679, 0.000725, 0.000553, 0.0005608, 0.000763,
          0.0006256, 0.00066, 0.000618, 0.00066, 0.0007057,
          0.0006943, 0.0005875, 0.000763, 0.0004845, 0.0006485,
          0.0006447, 0.0007095, 0.00058, 0.0007133, 0.0006065,
          0.000683, 0.000675, 0.0005302, 0.0006714, 0.0006294,
          0.0006485, 0.000618, 0.0006027, 0.0004997, 0.000763,
          0.000637, 0.0006866],
         [0.0003853, 0.0004196, 0.0003223, 0.0003185, 0.000431,
          0.0003624, 0.0003757, 0.0003567, 0.0003777, 0.0004005,
          0.0003986, 0.0003414, 0.0004444, 0.0002766, 0.0003662,
          0.0003586, 0.0004082, 0.0003319, 0.0004082, 0.000349,
          0.0003834, 0.000391, 0.0002995, 0.000389, 0.0003643,
          0.0003815, 0.000351, 0.0003433, 0.000286, 0.000431,
          0.0003586, 0.0003948],
         [0.0004253, 0.0004654, 0.0003471, 0.0003548, 0.0004826,
          0.000393, 0.0004082, 0.0003815, 0.0004025, 0.0004501,
          0.0004425, 0.000389, 0.0004864, 0.0002937, 0.0004177,
          0.0004025, 0.0004425, 0.0003624, 0.0004425, 0.0004005,
          0.0004292, 0.0004349, 0.00033, 0.0004349, 0.00041,
          0.0004158, 0.000393, 0.0003777, 0.0003128, 0.0004768,
          0.000393, 0.0004387]], dtype=np.float16
    )

    has_bias_output = np.array(
        [[0.015564, 0.0152, 0.00867, 0.01245, 0.0166, 0.014404,
          0.01074, 0.01373, 0.01746, 0.012695, 0.01184, 0.01428,
          0.01483, 0.0108, 0.01422, 0.015564, 0.008545, 0.01312,
          0.01532, 0.0152, 0.00897, 0.01526, 0.012634, 0.01141,
          0.007782, 0.01361, 0.01587, 0.01538, 0.01404, 0.01636,
          0.00964, 0.01068],
         [0.01599, 0.015564, 0.00903, 0.01282, 0.01721, 0.01477,
          0.01117, 0.01422, 0.01794, 0.01312, 0.01233, 0.01459,
          0.01526, 0.01117, 0.01471, 0.01599, 0.00891, 0.01361,
          0.01575, 0.01563, 0.00946, 0.01575, 0.01294, 0.01178,
          0.00818, 0.01404, 0.01636, 0.01575, 0.014465, 0.01672,
          0.01001, 0.01129],
         [0.0141, 0.01355, 0.007385, 0.01123, 0.01489, 0.013,
          0.00922, 0.01233, 0.01599, 0.01117, 0.010376, 0.01288,
          0.01306, 0.009766, 0.01276, 0.01416, 0.006897, 0.01178,
          0.01361, 0.013794, 0.007477, 0.01373, 0.01141, 0.00983,
          0.006348, 0.01221, 0.01453, 0.01404, 0.01294, 0.01453,
          0.00812, 0.009155],
         [0.0141, 0.01373, 0.007507, 0.01141, 0.015076, 0.01312,
          0.00934, 0.01251, 0.01599, 0.01123, 0.0105, 0.01288,
          0.01312, 0.009766, 0.01282, 0.01434, 0.006958, 0.01196,
          0.01367, 0.013916, 0.007538, 0.01398, 0.01141, 0.00989,
          0.00644, 0.01233, 0.01465, 0.01404, 0.01306, 0.01471,
          0.00824, 0.00922]], dtype=np.float16
    )

    output_only = np.array(
        [[0.007385, 0.00812, 0.006226, 0.006195, 0.00842, 0.006775,
          0.007416, 0.006927, 0.007355, 0.007782, 0.00763, 0.00656,
          0.008606, 0.00534, 0.007202, 0.006897, 0.007996, 0.00653,
          0.00806, 0.006714, 0.007446, 0.007446, 0.005737, 0.007477,
          0.006958, 0.00717, 0.006805, 0.006775, 0.005737, 0.0083,
          0.007202, 0.00775],
         [0.007812, 0.008484, 0.00659, 0.00653, 0.00897, 0.00714,
          0.007812, 0.007416, 0.00775, 0.00818, 0.00812, 0.006866,
          0.00903, 0.005768, 0.00769, 0.007355, 0.00836, 0.006958,
          0.008484, 0.00711, 0.007935, 0.007935, 0.006073, 0.007812,
          0.007355, 0.00757, 0.007324, 0.00714, 0.006134, 0.00873,
          0.0076, 0.00836],
         [0.00589, 0.00647, 0.004944, 0.004974, 0.006775, 0.00537,
          0.00586, 0.005554, 0.0058, 0.006256, 0.006134, 0.005188,
          0.006836, 0.004333, 0.005737, 0.005493, 0.006348, 0.005188,
          0.006348, 0.00531, 0.00592, 0.00592, 0.004547, 0.00592,
          0.005585, 0.005737, 0.005493, 0.00537, 0.00467, 0.00656,
          0.005707, 0.006226],
         [0.00592, 0.006622, 0.005066, 0.005127, 0.006927, 0.005493,
          0.006012, 0.005768, 0.00589, 0.006317, 0.006287, 0.00522,
          0.006897, 0.004364, 0.0058, 0.005646, 0.00641, 0.00534,
          0.00641, 0.0054, 0.00598, 0.006165, 0.004517, 0.00598,
          0.005646, 0.00589, 0.005585, 0.0054, 0.00476, 0.006714,
          0.0058, 0.006287]], dtype=np.float16
    )

    return {
        "has_bias_gate_output": has_bias_gate_output,
        "has_gate_output": has_gate_output,
        "has_bias_output": has_bias_output,
        "output_only": output_only,
    }


# Both reference dicts are built once, at import time, and then shared
# module-wide (presumably read by the MLP UT -- confirm against the caller).
GOLDEN_DATA: dict[str, np.ndarray] = get_golden()
GPU_DATA: dict[str, np.ndarray] = get_gpu_data()
